{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Copyright (c) Microsoft Corporation. All rights reserved.*  \n",
    "*Licensed under the MIT License.*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Named Entity Recognition Using Transformer Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "This notebook demonstrates how to fine tune [pretrained Transformer model](https://github.com/huggingface/transformers) for named entity recognition (NER) task. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, and model evaluation. \n",
    "\n",
    "The pretrained transformer of [BERT (Bidirectional Transformers for Language Understanding)](https://arxiv.org/pdf/1810.04805.pdf) architecture is used in this notebook. [BERT](https://arxiv.org/pdf/1810.04805.pdf) is a powerful pre-trained lanaguage model that can be used for multiple NLP tasks, including text classification, question answering, named entity recognition, etc. It's able to achieve state of the art performance with only a few epochs of fine tuning on task specific datasets.\n",
    "\n",
    "The figure below illustrates how BERT can be fine tuned for NER tasks. The input data is a list of tokens representing a sentence. In the training data, each token has an entity label. After fine tuning, the model predicts an entity label for each token in a given testing sentence. \n",
    "\n",
    "<img src=\"https://nlpbp.blob.core.windows.net/images/bert_architecture.png\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import string\n",
    "import sys\n",
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "import pandas as pd\n",
    "import scrapbook as sb\n",
    "import torch\n",
    "from seqeval.metrics import classification_report\n",
    "from sklearn.model_selection import train_test_split\n",
    "from utils_nlp.common.pytorch_utils import dataloader_from_dataset\n",
    "from utils_nlp.common.timer import Timer\n",
    "from utils_nlp.dataset import wikigold\n",
    "from utils_nlp.dataset.ner_utils import read_conll_file\n",
    "from utils_nlp.dataset.url_utils import maybe_download\n",
    "from utils_nlp.models.transformers.named_entity_recognition import (\n",
    "    TokenClassificationProcessor, TokenClassifier)\n",
    "from utils_nlp.models.transformers.named_entity_recognition import supported_models as SUPPORTED_MODELS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The running time shown in this notebook is on a Standard_NC12 Azure Virtual Machine with 2 NVIDIA Tesla K80 GPUs. \n",
    "> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. \n",
    "\n",
    "The table below provides some reference running time on different machine configurations.  \n",
    "\n",
    "|QUICK_RUN|Machine Configurations|Running time|\n",
    "|:---------|:----------------------|:------------|\n",
    "|True|4 CPUs, 14GB memory| ~ 2 minutes|\n",
    "|False|4 CPUs, 14GB memory| ~1.5 hours|\n",
    "|True|1 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 1 minute|\n",
    "|False|1 NVIDIA Tesla K80 GPUs, 12GB GPU memory| ~ 7 minutes |\n",
    "\n",
    "If you run into CUDA out-of-memory error or the jupyter kernel dies constantly, try reducing the `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that model performance will be compromised. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n",
    "QUICK_RUN = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Wikigold dataset\n",
    "DATA_URL = (\n",
    "    \"https://raw.githubusercontent.com/juand-r/entity-recognition-datasets\"\n",
    "    \"/master/data/wikigold/CONLL-format/data/wikigold.conll.txt\"\n",
    ")\n",
    "\n",
    "# fraction of the dataset used for testing\n",
    "TEST_DATA_FRACTION = 0.3\n",
    "\n",
    "# sub-sampling ratio\n",
    "SAMPLE_RATIO = 1\n",
    "\n",
    "# the data path used to save the downloaded data file\n",
    "DATA_PATH = TemporaryDirectory().name\n",
    "\n",
    "# the cache data path during find tuning\n",
    "CACHE_DIR = TemporaryDirectory().name\n",
    "\n",
    "# set random seeds\n",
    "RANDOM_SEED = 100\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "# model configurations\n",
    "NUM_TRAIN_EPOCHS = 5\n",
    "MODEL_NAME = \"distilbert-base-cased\"\n",
    "DO_LOWER_CASE = False\n",
    "MAX_SEQ_LENGTH = 200\n",
    "TRAILING_PIECE_TAG = \"X\"\n",
    "NUM_GPUS = None  # uses all if available\n",
    "BATCH_SIZE = 16\n",
    "\n",
    "# update variables for quick run option\n",
    "if QUICK_RUN:\n",
    "    SAMPLE_RATIO = 0.1\n",
    "    NUM_TRAIN_EPOCHS = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Models that can be used for token classification task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>supported models</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>albert-base-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>albert-base-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>albert-large-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>albert-large-v2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>albert-xlarge-v1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>65</th>\n",
       "      <td>xlm-roberta-large-finetuned-conll02-spanish</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>66</th>\n",
       "      <td>xlm-roberta-large-finetuned-conll03-english</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>67</th>\n",
       "      <td>xlm-roberta-large-finetuned-conll03-german</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>68</th>\n",
       "      <td>xlnet-base-cased</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>69</th>\n",
       "      <td>xlnet-large-cased</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>70 rows × 1 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                               supported models\n",
       "0                                albert-base-v1\n",
       "1                                albert-base-v2\n",
       "2                               albert-large-v1\n",
       "3                               albert-large-v2\n",
       "4                              albert-xlarge-v1\n",
       "..                                          ...\n",
       "65  xlm-roberta-large-finetuned-conll02-spanish\n",
       "66  xlm-roberta-large-finetuned-conll03-english\n",
       "67   xlm-roberta-large-finetuned-conll03-german\n",
       "68                             xlnet-base-cased\n",
       "69                            xlnet-large-cased\n",
       "\n",
       "[70 rows x 1 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.DataFrame({\"supported models\": SUPPORTED_MODELS})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Traning & Testing Dataset\n",
    "\n",
    "The dataset used in this notebook is the [wikigold dataset](https://www.aclweb.org/anthology/W09-3302). The wikigold dataset consists of 145 mannually labelled Wikipedia articles, including 1841 sentences and 40k tokens in total. The dataset can be directly downloaded from [here](https://github.com/juand-r/entity-recognition-datasets/tree/master/data/wikigold). \n",
    "\n",
    "In the following cell, we download the data file, parse the tokens and labels, sample a given number of sentences, and split the dataset for training and testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 96.0/96.0 [00:00<00:00, 4.02kKB/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum sequence length is: 144\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# download data\n",
    "file_name = DATA_URL.split(\"/\")[-1]  # a name for the downloaded file\n",
    "maybe_download(DATA_URL, file_name, DATA_PATH)\n",
    "data_file = os.path.join(DATA_PATH, file_name)\n",
    "\n",
    "# parse CoNll file\n",
    "sentence_list, labels_list = read_conll_file(data_file, sep=\" \")\n",
    "\n",
    "# sub-sample (optional)\n",
    "random.seed(RANDOM_SEED)\n",
    "sample_size = int(SAMPLE_RATIO * len(sentence_list))\n",
    "sentence_list, labels_list = list(\n",
    "    zip(*random.sample(list(zip(sentence_list, labels_list)), k=sample_size))\n",
    ")\n",
    "\n",
    "# train-test split\n",
    "train_sentence_list, test_sentence_list, train_labels_list, test_labels_list = train_test_split(\n",
    "    sentence_list, labels_list, test_size=TEST_DATA_FRACTION, random_state=RANDOM_SEED\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following is an example input sentence of the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentence</th>\n",
       "      <th>labels</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[The, origin, of, Agotes, (, or, Cagots, ), is...</td>\n",
       "      <td>[O, O, O, I-MISC, O, O, I-MISC, O, O, O, O]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[-DOCSTART-]</td>\n",
       "      <td>[O]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[It, provides, full, -, and, part-time, polyte...</td>\n",
       "      <td>[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[Since, she, was, the, daughter, of, the, grea...</td>\n",
       "      <td>[O, O, O, O, O, O, O, O, I-MISC, O, O, O, I-MI...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[The, goals, were, two, posts, ,, with, no, cr...</td>\n",
       "      <td>[O, O, O, O, O, O, O, O, O, O]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>[At, one, point, ,, so, many, orders, had, bee...</td>\n",
       "      <td>[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>[Left, camp, in, July, 1972, ,, and, was, deal...</td>\n",
       "      <td>[O, O, O, O, O, O, O, O, O, O, O, I-ORG, I-ORG...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>[She, fled, again, to, Abra, ,, where, she, wa...</td>\n",
       "      <td>[O, O, O, O, I-LOC, O, O, O, O, O, O]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>[As, the, younger, sibling, ,, Ben, was, const...</td>\n",
       "      <td>[O, O, O, O, O, I-PER, O, O, O, O, O, O, O, O,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>[Milepost, 1, :, granite, masonry, arch, over,...</td>\n",
       "      <td>[O, O, O, O, O, O, O, I-LOC, I-LOC, O]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                            sentence  \\\n",
       "0  [The, origin, of, Agotes, (, or, Cagots, ), is...   \n",
       "1                                       [-DOCSTART-]   \n",
       "2  [It, provides, full, -, and, part-time, polyte...   \n",
       "3  [Since, she, was, the, daughter, of, the, grea...   \n",
       "4  [The, goals, were, two, posts, ,, with, no, cr...   \n",
       "5  [At, one, point, ,, so, many, orders, had, bee...   \n",
       "6  [Left, camp, in, July, 1972, ,, and, was, deal...   \n",
       "7  [She, fled, again, to, Abra, ,, where, she, wa...   \n",
       "8  [As, the, younger, sibling, ,, Ben, was, const...   \n",
       "9  [Milepost, 1, :, granite, masonry, arch, over,...   \n",
       "\n",
       "                                              labels  \n",
       "0        [O, O, O, I-MISC, O, O, I-MISC, O, O, O, O]  \n",
       "1                                                [O]  \n",
       "2  [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...  \n",
       "3  [O, O, O, O, O, O, O, O, I-MISC, O, O, O, I-MI...  \n",
       "4                     [O, O, O, O, O, O, O, O, O, O]  \n",
       "5  [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...  \n",
       "6  [O, O, O, O, O, O, O, O, O, O, O, I-ORG, I-ORG...  \n",
       "7              [O, O, O, O, I-LOC, O, O, O, O, O, O]  \n",
       "8  [O, O, O, O, O, I-PER, O, O, O, O, O, O, O, O,...  \n",
       "9             [O, O, O, O, O, O, O, I-LOC, I-LOC, O]  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show example sentences from input\n",
    "pd.DataFrame({\"sentence\": sentence_list, \"labels\": labels_list}).head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>token</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>In</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1999</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>,</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>the</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Caloi</td>\n",
       "      <td>I-PER</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>family</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>sold</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>the</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>majority</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>of</td>\n",
       "      <td>O</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>Caloi</td>\n",
       "      <td>I-ORG</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       token  label\n",
       "0         In      O\n",
       "1       1999      O\n",
       "2          ,      O\n",
       "3        the      O\n",
       "4      Caloi  I-PER\n",
       "5     family      O\n",
       "6       sold      O\n",
       "7        the      O\n",
       "8   majority      O\n",
       "9         of      O\n",
       "10     Caloi  I-ORG"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show example tokens from input\n",
    "pd.DataFrame({\"token\": train_sentence_list[0], \"label\": train_labels_list[0]}).head(11)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> If your data is unlabeled, try using an annotation tool to simplify the process of labeling. The example [here](../annotation/Doccano.md) introduces [Doccanno](https://github.com/chakki-works/doccano) and shows how it can be used for NER annotation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create PyTorch Datasets and Dataloaders\n",
    "Given the tokenized input and corresponding labels, we use a custom processer to convert our input lists into a PyTorch dataset that can be used with our token classifier. Next, we create PyTorch dataloaders for training and testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ea57217fe6394812af03defcdaffe4db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Downloading', max=411, style=ProgressStyle(description_width=…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00884141779a4ddead34204d5ea01b41",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Downloading', max=213450, style=ProgressStyle(description_wid…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Token lists with length > 512 will be truncated\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Token lists with length > 512 will be truncated\n"
     ]
    }
   ],
   "source": [
    "processor = TokenClassificationProcessor(model_name=MODEL_NAME, to_lower=DO_LOWER_CASE, cache_dir=CACHE_DIR)\n",
    "\n",
    "label_map = TokenClassificationProcessor.create_label_map(\n",
    "    label_lists=labels_list, trailing_piece_tag=TRAILING_PIECE_TAG\n",
    ")\n",
    "\n",
    "train_dataset = processor.preprocess(\n",
    "    text=train_sentence_list,\n",
    "    max_len=MAX_SEQ_LENGTH,\n",
    "    labels=train_labels_list,\n",
    "    label_map=label_map,\n",
    "    trailing_piece_tag=TRAILING_PIECE_TAG,\n",
    ")\n",
    "train_dataloader = dataloader_from_dataset(\n",
    "    train_dataset, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=True, distributed=False\n",
    ")\n",
    "\n",
    "test_dataset = processor.preprocess(\n",
    "    text=test_sentence_list,\n",
    "    max_len=MAX_SEQ_LENGTH,\n",
    "    labels=test_labels_list,\n",
    "    label_map=label_map,\n",
    "    trailing_piece_tag=TRAILING_PIECE_TAG,\n",
    ")\n",
    "test_dataloader = dataloader_from_dataset(\n",
    "    test_dataset, batch_size=BATCH_SIZE, num_gpus=NUM_GPUS, shuffle=False, distributed=False\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Model\n",
    "\n",
    "There are two steps to train a NER model using pretrained transformer model: 1) Instantiate a TokenClassifier class which is a wrapper of a transformer-based network, and 2) Fit the model using the preprocessed training dataloader. The member method `fit` of TokenClassifier class is used to fine-tune the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7cd3a9259b5c42638e8580f9fbae27db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Downloading', max=263273408, style=ProgressStyle(description_…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Training time : 0.060 hrs\n"
     ]
    }
   ],
   "source": [
    "# Instantiate a TokenClassifier class for NER using pretrained transformer model\n",
    "model = TokenClassifier(\n",
    "    model_name=MODEL_NAME,\n",
    "    num_labels=len(label_map),\n",
    "    cache_dir=CACHE_DIR\n",
    ")\n",
    "\n",
    "# Fine tune the model using the training dataset\n",
    "with Timer() as t:\n",
    "    model.fit(\n",
    "        train_dataloader=train_dataloader,\n",
    "        num_epochs=NUM_TRAIN_EPOCHS,\n",
    "        num_gpus=NUM_GPUS,\n",
    "        local_rank=-1,\n",
    "        weight_decay=0.0,\n",
    "        learning_rate=5e-5,\n",
    "        adam_epsilon=1e-8,\n",
    "        warmup_steps=0,\n",
    "        verbose=False,\n",
    "        seed=RANDOM_SEED\n",
    "    )\n",
    "\n",
    "print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate on Testing Dataset\n",
    "\n",
    "The `predict` method of the TokenClassifier returns a Numpy ndarray of raw predictions. The shape of the ndarray is \\[`number_of_examples`, `sequence_length`, `number_of_labels`\\]. Each value in the ndarray is not normalized. Post-process will be needed to get the probability for each class label. Function `get_predicted_token_labels` will process the raw prediction and output the predicted labels for each token."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Scoring: 100%|██████████| 35/35 [00:06<00:00,  6.14it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction time : 0.002 hrs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "with Timer() as t:\n",
    "    preds = model.predict(\n",
    "        test_dataloader=test_dataloader,\n",
    "        num_gpus=None,\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the true token labels of the testing dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_labels = model.get_true_test_labels(label_map=label_map, dataset=test_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the predicted labels for each token by calling member method `get_predicted_token_labels`, and generate the classification report."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "           precision    recall  f1-score   support\n",
      "\n",
      "      ORG       0.72      0.76      0.74       274\n",
      "     MISC       0.67      0.73      0.70       221\n",
      "      LOC       0.79      0.84      0.81       317\n",
      "      PER       0.90      0.93      0.92       257\n",
      "\n",
      "micro avg       0.76      0.82      0.79      1069\n",
      "macro avg       0.77      0.82      0.79      1069\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predicted_labels = model.get_predicted_token_labels(\n",
    "    predictions=preds,\n",
    "    label_map=label_map,\n",
    "    dataset=test_dataset\n",
    ")\n",
    "\n",
    "report = classification_report(true_labels, \n",
    "              predicted_labels, \n",
    "              digits=2\n",
    ")\n",
    "\n",
    "print(report)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Score Example Sentences\n",
    "Finally, we test the model on some random input sentences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Token lists with length > 512 will be truncated\n",
      "Scoring: 100%|██████████| 1/1 [00:00<00:00, 25.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " Is it true that Jane works at Microsoft?\n",
      "       tokens labels\n",
      "0          Is      O\n",
      "1          it      O\n",
      "2        true      O\n",
      "3        that      O\n",
      "4        Jane  I-PER\n",
      "5       works      O\n",
      "6          at      O\n",
      "7  Microsoft?  I-ORG\n",
      "\n",
      " Joe now lives in Copenhagen.\n",
      "        tokens labels\n",
      "0          Joe  I-PER\n",
      "1          now      O\n",
      "2        lives      O\n",
      "3           in      O\n",
      "4  Copenhagen.  I-LOC\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# test\n",
    "sample_text = [    \n",
    "    \"Is it true that Jane works at Microsoft?\",\n",
    "    \"Joe now lives in Copenhagen.\"\n",
    "]\n",
    "sample_tokens = [x.split() for x in sample_text]\n",
    "\n",
    "sample_dataset = processor.preprocess(\n",
    "    text=sample_tokens,\n",
    "    max_len=MAX_SEQ_LENGTH,\n",
    "    labels=None,\n",
    "    label_map=label_map,\n",
    "    trailing_piece_tag=TRAILING_PIECE_TAG,\n",
    ")\n",
    "sample_dataloader = dataloader_from_dataset(\n",
    "    sample_dataset, batch_size=BATCH_SIZE, num_gpus=None, shuffle=False, distributed=False\n",
    ")\n",
    "preds = model.predict(\n",
    "        test_dataloader=sample_dataloader,\n",
    "        num_gpus=None,\n",
    "        verbose=True\n",
    ")\n",
    "predicted_labels = model.get_predicted_token_labels(\n",
    "    predictions=preds,\n",
    "    label_map=label_map,\n",
    "    dataset=sample_dataset\n",
    ")\n",
    "\n",
    "for i in range(len(sample_text)):\n",
    "    print(\"\\n\", sample_text[i])\n",
    "    print(pd.DataFrame({\"tokens\": sample_tokens[i] , \"labels\":predicted_labels[i]}))  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## For Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.77,
       "encoder": "json",
       "name": "precision",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "precision"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.82,
       "encoder": "json",
       "name": "recall",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "recall"
      }
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 0.79,
       "encoder": "json",
       "name": "f1",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "f1"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "report_splits = report.split('\\n')[-2].split()\n",
    "\n",
    "sb.glue(\"precision\", float(report_splits[2]))\n",
    "sb.glue(\"recall\", float(report_splits[3]))\n",
    "sb.glue(\"f1\", float(report_splits[4]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (nlp_gpu)",
   "language": "python",
   "name": "nlp_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
